/*
 * Copyright 1999-2015 dangdang.com.
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * </p>
 */

package com.dangdang.ddframe.rdb.sharding.jdbc;

import com.dangdang.ddframe.rdb.sharding.executor.StatementExecutor;
import com.dangdang.ddframe.rdb.sharding.executor.wrapper.StatementExecutorWrapper;
import com.dangdang.ddframe.rdb.sharding.jdbc.adapter.AbstractStatementAdapter;
import com.dangdang.ddframe.rdb.sharding.merger.ResultSetFactory;
import com.dangdang.ddframe.rdb.sharding.parser.result.GeneratedKeyContext;
import com.dangdang.ddframe.rdb.sharding.parser.result.merger.MergeContext;
import com.dangdang.ddframe.rdb.sharding.router.SQLExecutionUnit;
import com.dangdang.ddframe.rdb.sharding.router.SQLRouteResult;
import com.google.common.base.Function;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import com.google.common.collect.Table;
import com.google.common.collect.TreeBasedTable;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;

import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/**
 * 支持分片的静态语句对象.
 * 
 * @author gaohongtao
 * @author caohao
 */
public class ShardingStatement extends AbstractStatementAdapter {
    
    private static final Function<BackendStatementWrapper, Statement> TRANSFORM_FUNCTION = new Function<BackendStatementWrapper, Statement>() {
        @Override
        public Statement apply(final BackendStatementWrapper input) {
            return input.getStatement();
        }
    };
    
    @Getter(AccessLevel.PROTECTED)
    private final ShardingConnection shardingConnection;
    
    @Getter
    private final int resultSetType;
    
    @Getter
    private final int resultSetConcurrency;
    
    @Getter
    private final int resultSetHoldability;
    
    private final Deque<List<BackendStatementWrapper>> cachedRoutedStatements = Lists.newLinkedList();
    
    @Getter(AccessLevel.PROTECTED)
    @Setter(AccessLevel.PROTECTED)
    private MergeContext mergeContext;
    
    @Setter(AccessLevel.PROTECTED)
    private ResultSet currentResultSet;
    
    @Getter(AccessLevel.PROTECTED)
    @Setter(AccessLevel.PROTECTED)
    private GeneratedKeyContext generatedKeyContext;
    
    private ResultSet generatedKeyResultSet;
    
    ShardingStatement(final ShardingConnection shardingConnection) {
        this(shardingConnection, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY, ResultSet.HOLD_CURSORS_OVER_COMMIT);
    }
    
    ShardingStatement(final ShardingConnection shardingConnection, final int resultSetType, final int resultSetConcurrency) {
        this(shardingConnection, resultSetType, resultSetConcurrency, ResultSet.HOLD_CURSORS_OVER_COMMIT);
    }
    
    public ShardingStatement(final ShardingConnection shardingConnection, final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability) {
        super(Statement.class);
        this.shardingConnection = shardingConnection;
        this.resultSetType = resultSetType;
        this.resultSetConcurrency = resultSetConcurrency;
        this.resultSetHoldability = resultSetHoldability;
        cachedRoutedStatements.add(new LinkedList<BackendStatementWrapper>());
        cachedRoutedStatements.add(new LinkedList<BackendStatementWrapper>());
    }
    
    @Override
    public Connection getConnection() throws SQLException {
        return shardingConnection;
    }
    
    @Override
    public ResultSet executeQuery(final String sql) throws SQLException {
        ResultSet rs;
        try {
            rs = ResultSetFactory.getResultSet(generateExecutor(sql).executeQuery(), mergeContext);
        } finally {
            clearRouteContext();
        }
        setCurrentResultSet(rs);
        return rs;
    }
    
    @Override
    public int executeUpdate(final String sql) throws SQLException {
        try {
            return generateExecutor(sql).executeUpdate();
        } finally {
            clearRouteContext();
        }
    }
    
    @Override
    public int executeUpdate(final String sql, final int autoGeneratedKeys) throws SQLException {
        try {
            return generateExecutor(sql).executeUpdate(autoGeneratedKeys);
        } finally {
            if (null != generatedKeyContext) {
                generatedKeyContext.setAutoGeneratedKeys(autoGeneratedKeys);
            }
            
            clearRouteContext();
        }
    }
    
    @Override
    public int executeUpdate(final String sql, final int[] columnIndexes) throws SQLException {
        try {
            return generateExecutor(sql).executeUpdate(columnIndexes);
        } finally {
            if (null != generatedKeyContext) {
                generatedKeyContext.setColumnIndexes(columnIndexes);
            }
            clearRouteContext();
        }
    }
    
    @Override
    public int executeUpdate(final String sql, final String[] columnNames) throws SQLException {
        try {
            return generateExecutor(sql).executeUpdate(columnNames);
        } finally {
            if (null != generatedKeyContext) {
                generatedKeyContext.setColumnNames(columnNames);
            }
            clearRouteContext();
        }
    }
    
    @Override
    public boolean execute(final String sql) throws SQLException {
        try {
            return generateExecutor(sql).execute();
        } finally {
            clearRouteContext();
        }
    }
    
    @Override
    public boolean execute(final String sql, final int autoGeneratedKeys) throws SQLException {
        try {
            return generateExecutor(sql).execute(autoGeneratedKeys);
        } finally {
            if (null != generatedKeyContext) {
                generatedKeyContext.setAutoGeneratedKeys(autoGeneratedKeys);
            }
            clearRouteContext();
        }
    }
    
    @Override
    public boolean execute(final String sql, final int[] columnIndexes) throws SQLException {
        try {
            return generateExecutor(sql).execute(columnIndexes);
        } finally {
            if (null != generatedKeyContext) {
                generatedKeyContext.setColumnIndexes(columnIndexes);
            }
            clearRouteContext();
        }
    }
    
    @Override
    public boolean execute(final String sql, final String[] columnNames) throws SQLException {
        try {
            return generateExecutor(sql).execute(columnNames);
        } finally {
            if (null != generatedKeyContext) {
                generatedKeyContext.setColumnNames(columnNames);
            }
            clearRouteContext();
        }
    }
    
    @Override
    public final ResultSet getGeneratedKeys() throws SQLException {
        if (null != generatedKeyResultSet) {
            return generatedKeyResultSet;
        }
        if (null == generatedKeyContext || generatedKeyContext.getColumnNameToIndexMap().isEmpty()) {
            Collection<? extends Statement> routedStatements = getRoutedStatements();
            if (1 == routedStatements.size()) {
                return generatedKeyResultSet = routedStatements.iterator().next().getGeneratedKeys();
            }
        }
        if (Statement.RETURN_GENERATED_KEYS != generatedKeyContext.getAutoGeneratedKeys() && null == generatedKeyContext.getColumnIndexes()
                && null == generatedKeyContext.getColumnNames()) {
            
            return generatedKeyResultSet = new GeneratedKeysResultSet();
        }
        return generatedKeyResultSet = new GeneratedKeysResultSet(generateAutoIncrementTable(), generatedKeyContext.getColumnNameToIndexMap(), this);
    }
    
    private Table<Integer, Integer, Object> generateAutoIncrementTable() {
        if (null != generatedKeyContext.getColumnIndexes()) {
            return subTable(generatedKeyContext.getColumnIndexes());
        } else if (null != generatedKeyContext.getColumnNames()) {
            List<Integer> columnIndexes = new ArrayList<>(generatedKeyContext.getColumnNames().length);
            for (String each : generatedKeyContext.getColumnNames()) {
                if (!generatedKeyContext.getColumnNameToIndexMap().containsKey(each)) {
                    continue;
                }
                columnIndexes.add(generatedKeyContext.getColumnNameToIndexMap().get(each) + 1);
            }
            int[] parameter = new int[columnIndexes.size()];
            int index = 0;
            for (Integer each : columnIndexes) {
                parameter[index++] = each;
            }
            return subTable(parameter);
        }
        return generatedKeyContext.getValueTable();
    }
    
    private Table<Integer, Integer, Object> subTable(final int[] columnIndexes) {
        Table<Integer, Integer, Object> result = TreeBasedTable.create();
        for (int each : columnIndexes) {
            for (Map.Entry<Integer, Object> eachEntry : generatedKeyContext.getValueTable().column(each - 1).entrySet()) {
                result.put(eachEntry.getKey(), each - 1, eachEntry.getValue());
            }
        }
        return result;
    }
    
    protected void clearRouteContext() throws SQLException {
        setCurrentResultSet(null);
        List<BackendStatementWrapper> firstList = cachedRoutedStatements.pollFirst();
        cachedRoutedStatements.getFirst().addAll(firstList);
        firstList.clear();
        cachedRoutedStatements.addLast(firstList);
        generatedKeyResultSet = null;
    }
    
    private StatementExecutor generateExecutor(final String sql) throws SQLException {
        StatementExecutor result = new StatementExecutor(shardingConnection.getShardingContext().getExecutorEngine());
        SQLRouteResult sqlRouteResult = shardingConnection.getShardingContext().getSqlRouteEngine().route(sql);
        generatedKeyContext = sqlRouteResult.getGeneratedKeyContext();
        mergeContext = sqlRouteResult.getMergeContext();
        for (SQLExecutionUnit each : sqlRouteResult.getExecutionUnits()) {
            Statement statement = getStatement(shardingConnection.getConnection(each.getDataSource(), sqlRouteResult.getSqlStatementType()), each.getSql());
            replayMethodsInvocation(statement);
            result.addStatement(new StatementExecutorWrapper(statement, each));
        }
        return result;
    }
    
    protected Statement getStatement(final Connection connection, final String sql) throws SQLException {
        BackendStatementWrapper statement = null;
        for (Iterator<BackendStatementWrapper> iterator = cachedRoutedStatements.getFirst().iterator(); iterator.hasNext();) {
            BackendStatementWrapper each = iterator.next();
            if (each.isBelongTo(connection, sql)) {
                statement = each;
                iterator.remove();
            }
        }
        if (null == statement) {
            statement = generateStatement(connection, sql);
        }
        cachedRoutedStatements.getLast().add(statement);
        return statement.getStatement();
    }
    
    protected BackendStatementWrapper generateStatement(final Connection connection, final String sql) throws SQLException {
        Statement result;
        if (0 == resultSetHoldability) {
            result = connection.createStatement(resultSetType, resultSetConcurrency);
        } else {
            result = connection.createStatement(resultSetType, resultSetConcurrency, resultSetHoldability);
        }
        return new BackendStatementWrapper(result);
    }
    
    @Override
    public ResultSet getResultSet() throws SQLException {
        if (null != currentResultSet) {
            return currentResultSet;
        }
        List<ResultSet> resultSets = new ArrayList<>(getRoutedStatements().size());
        if (getRoutedStatements().size() == 1) {
            currentResultSet = getRoutedStatements().iterator().next().getResultSet();
            return currentResultSet;
        }
        for (Statement each : getRoutedStatements()) {
            resultSets.add(each.getResultSet());
        }
        currentResultSet = ResultSetFactory.getResultSet(resultSets, mergeContext);
        return currentResultSet;
    }
    
    @Override
    protected void clearRouteStatements() {
        cachedRoutedStatements.getFirst().clear();
        cachedRoutedStatements.getLast().clear();
    }
    
    @Override
    public Collection<? extends Statement> getRoutedStatements() {
        return  Lists.newArrayList(Iterators.concat(Iterators.transform(cachedRoutedStatements.getFirst().iterator(), TRANSFORM_FUNCTION),
                Iterators.transform(cachedRoutedStatements.getLast().iterator(), TRANSFORM_FUNCTION)));
    }
}
